#!/usr/bin/python
# -*- coding:utf-8 -*-
# @FileName : DL5_test1_1.py
# Author    : myh

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

def build_net():
    """Build an MLP classifier whose two middle hidden layers share weights.

    The model flattens each input sample to a vector, maps it from 784 to
    256 features, applies the *same* 256->256 linear layer twice (with a
    ReLU after each application), and finally projects to 10 outputs.

    Returns:
        nn.Sequential: the assembled network. The shared layer is the same
        module object at both of its positions, so its weight and bias are
        a single set of parameters.
    """
    # The shared layer needs a name so that the same module object can be
    # inserted twice and its parameters can be referred to.
    shared = nn.Linear(256, 256)
    return nn.Sequential(
        # Without this, image batches shaped (batch, 1, 28, 28) reach the
        # first Linear with a last dimension of 28 instead of 784 and the
        # forward pass fails with a shape-mismatch error.
        nn.Flatten(),
        nn.Linear(784, 256), nn.ReLU(),
        shared, nn.ReLU(),
        shared, nn.ReLU(),
        nn.Linear(256, 10),
    )


if __name__ == '__main__':
    num_epochs, lr, batch_size = 10, 0.5, 256
    # Per-sample losses (no reduction here).
    # NOTE(review): d2l.train_ch3 is expected to reduce them itself --
    # confirm against the installed d2l version.
    loss = nn.CrossEntropyLoss(reduction='none')
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    net = build_net()

    # net.parameters() yields the shared layer's parameters only once, so
    # the optimizer updates them a single time per step.
    trainer = torch.optim.SGD(net.parameters(), lr=lr)
    d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)


